{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/run-llama/llama_index/blob/main/docs/examples/finetuning/gradient/gradient_structured.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine Tuning Llama2 for Better Structured Outputs With Gradient and LlamaIndex\n",
    "\n",
    "In this notebook we show you how to fine-tune llama2-7b to produce better structured outputs.\n",
    "\n",
    "We do this using [gradient.ai](https://gradient.ai).\n",
    "\n",
    "This is similar in format to our [OpenAI Functions Fine-tuning Notebook](https://docs.llamaindex.ai/en/latest/examples/finetuning/openai_fine_tuning_functions.html).\n",
    "\n",
    "**NOTE**: This is an alternative to our repo/guide on fine-tuning llama2-7b with Modal: https://github.com/run-llama/modal_finetune_sql"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install llama-index-llms-gradient\n",
    "%pip install llama-index-llms-openai\n",
    "%pip install llama-index-readers-file\n",
    "%pip install llama-index-finetuning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install llama-index gradientai -q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from llama_index.llms.gradient import GradientBaseModelLLM\n",
    "from llama_index.finetuning import GradientFinetuneEngine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ[\"GRADIENT_ACCESS_TOKEN\"] = os.getenv(\"GRADIENT_API_KEY\")\n",
    "os.environ[\"GRADIENT_WORKSPACE_ID\"] = \"<insert_workspace_id>\""
   ]
  },
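  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell above copies `GRADIENT_API_KEY` into `GRADIENT_ACCESS_TOKEN`. If `GRADIENT_API_KEY` is unset, `os.getenv` returns `None` and the assignment raises a `TypeError`. As a minimal sketch, a pre-flight check along these lines can report what is missing; the helper name `missing_env_vars` and the demo values are hypothetical, not part of LlamaIndex or Gradient.\n",
    "\n",
    "```python\n",
    "import os\n",
    "\n",
    "\n",
    "def missing_env_vars(names, env=None):\n",
    "    # Return the names that are unset or empty in the given mapping.\n",
    "    env = os.environ if env is None else env\n",
    "    return [name for name in names if not env.get(name)]\n",
    "\n",
    "\n",
    "# Hypothetical demo values: a plain dict stands in for os.environ.\n",
    "demo_env = {\n",
    "    'GRADIENT_ACCESS_TOKEN': 'token-placeholder',\n",
    "    'GRADIENT_WORKSPACE_ID': '',\n",
    "}\n",
    "missing = missing_env_vars(\n",
    "    ['GRADIENT_ACCESS_TOKEN', 'GRADIENT_WORKSPACE_ID'], env=demo_env\n",
    ")\n",
    "print(missing)\n",
    "```\n",
    "\n",
    "Passing a plain dict keeps the demo independent of the real environment; calling the helper without `env` checks `os.environ` instead."
   ]
  },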
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fine-tuning Using GPT-4 Pydantic Programs\n",
    "\n",
    "In this section we show how to log inputs + GPT-4 generated outputs through our low-level Pydantic Program module. We use that dataset to fine-tune llama2."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pydantic import BaseModel\n",
    "\n",
    "\n",
    "class Album(BaseModel):\n",
    "    \"\"\"Data model for an album.\"\"\"\n",
    "\n",
    "    name: str\n",
    "    artist: str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.callbacks import CallbackManager, LlamaDebugHandler\n",
    "from llama_index.llms.openai import OpenAI\n",
    "from llama_index.llms.gradient import GradientBaseModelLLM\n",
    "from llama_index.core.program import LLMTextCompletionProgram\n",
    "from llama_index.core.output_parsers import PydanticOutputParser\n",
    "\n",
    "openai_handler = LlamaDebugHandler()\n",
    "openai_callback = CallbackManager([openai_handler])\n",
    "openai_llm = OpenAI(model=\"gpt-4\", callback_manager=openai_callback)\n",
    "\n",
    "gradient_handler = LlamaDebugHandler()\n",
    "gradient_callback = CallbackManager([gradient_handler])\n",
    "base_model_slug = \"llama2-7b-chat\"\n",
    "gradient_llm = GradientBaseModelLLM(\n",
    "    base_model_slug=base_model_slug,\n",
    "    max_tokens=300,\n",
    "    callback_manager=gradient_callback,\n",
    "    is_chat_model=True,\n",
    ")\n",
    "# HACK: set chat model\n",
    "from llama_index.core.llms import LLMMetadata\n",
    "\n",
    "# gradient_llm.metadata = LLMMetadata(\n",
    "#     context_window=1024,\n",
    "#     num_output=gradient_llm.max_tokens or 20,\n",
    "#     is_chat_model=True,\n",
    "#     is_function_calling_model=False,\n",
    "#     model_name=gradient_llm._model.id,\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# try running both through LLMTextCompletionProgram\n",
    "\n",
    "prompt_template_str = \"\"\"\\\n",
    "Generate an example album, with an artist and a list of songs. \\\n",
    "Using the movie {movie_name} as inspiration.\\\n",
    "\"\"\"\n",
    "openai_program = LLMTextCompletionProgram.from_defaults(\n",
    "    output_parser=PydanticOutputParser(Album),\n",
    "    prompt_template_str=prompt_template_str,\n",
    "    llm=openai_llm,\n",
    "    verbose=True,\n",
    ")\n",
    "gradient_program = LLMTextCompletionProgram.from_defaults(\n",
    "    output_parser=PydanticOutputParser(Album),\n",
    "    prompt_template_str=prompt_template_str,\n",
    "    llm=gradient_llm,\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "response = openai_program(movie_name=\"The Shining\")\n",
    "print(str(response))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tmp = openai_handler.get_llm_inputs_outputs()\n",
    "print(tmp[0][0].payload[\"messages\"][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# print(tmp[0][1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "response = gradient_program(movie_name=\"The Shining\")\n",
    "print(str(response))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tmp = gradient_handler.get_llm_inputs_outputs()\n",
    "print(tmp[0][0].payload[\"messages\"][0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Defining Pydantic Model + Program\n",
    "\n",
    "Here, we define the GPT-4 powered program (an `LLMTextCompletionProgram` with a Pydantic output parser) that will generate structured outputs as a Pydantic object (an `Album`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.program import LLMTextCompletionProgram\n",
    "from pydantic import BaseModel\n",
    "from llama_index.llms.openai import OpenAI\n",
    "from llama_index.core.callbacks import GradientAIFineTuningHandler\n",
    "from llama_index.core.callbacks import CallbackManager\n",
    "from llama_index.core.output_parsers import PydanticOutputParser\n",
    "from typing import List\n",
    "\n",
    "\n",
    "class Song(BaseModel):\n",
    "    \"\"\"Data model for a song.\"\"\"\n",
    "\n",
    "    title: str\n",
    "    length_seconds: int\n",
    "\n",
    "\n",
    "class Album(BaseModel):\n",
    "    \"\"\"Data model for an album.\"\"\"\n",
    "\n",
    "    name: str\n",
    "    artist: str\n",
    "    songs: List[Song]\n",
    "\n",
    "\n",
    "finetuning_handler = GradientAIFineTuningHandler()\n",
    "callback_manager = CallbackManager([finetuning_handler])\n",
    "\n",
    "llm_gpt4 = OpenAI(model=\"gpt-4\", callback_manager=callback_manager)\n",
    "\n",
    "\n",
    "prompt_template_str = \"\"\"\\\n",
    "Generate an example album, with an artist and a list of songs. \\\n",
    "Using the movie {movie_name} as inspiration.\\\n",
    "\"\"\"\n",
    "openai_program = LLMTextCompletionProgram.from_defaults(\n",
    "    output_parser=PydanticOutputParser(Album),\n",
    "    prompt_template_str=prompt_template_str,\n",
    "    llm=llm_gpt4,\n",
    "    verbose=True,\n",
    ")"
   ]
  },
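  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the parsing step concrete: the program asks the LLM for JSON text and then validates it against the `Album` schema. The sketch below mimics that idea with only the standard library. `parse_album` and the sample album are hypothetical stand-ins for illustration; this is not how `PydanticOutputParser` is implemented.\n",
    "\n",
    "```python\n",
    "import json\n",
    "from dataclasses import dataclass\n",
    "from typing import List\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class Song:\n",
    "    title: str\n",
    "    length_seconds: int\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class Album:\n",
    "    name: str\n",
    "    artist: str\n",
    "    songs: List[Song]\n",
    "\n",
    "\n",
    "def parse_album(text):\n",
    "    # Parse JSON text into an Album; raise ValueError on malformed output.\n",
    "    try:\n",
    "        raw = json.loads(text)\n",
    "        songs = [\n",
    "            Song(str(s['title']), int(s['length_seconds']))\n",
    "            for s in raw['songs']\n",
    "        ]\n",
    "        return Album(str(raw['name']), str(raw['artist']), songs)\n",
    "    except (json.JSONDecodeError, KeyError, TypeError) as exc:\n",
    "        raise ValueError(f'not a valid Album: {exc!r}') from exc\n",
    "\n",
    "\n",
    "# Hypothetical well-formed model output.\n",
    "good = json.dumps(\n",
    "    {\n",
    "        'name': 'Overlook',\n",
    "        'artist': 'The Caretakers',\n",
    "        'songs': [{'title': 'Room 237', 'length_seconds': 200}],\n",
    "    }\n",
    ")\n",
    "album = parse_album(good)\n",
    "print(album.name, len(album.songs))\n",
    "\n",
    "# Free-form text, as an untuned model might return, fails to parse.\n",
    "try:\n",
    "    parse_album('Sure! Here is an album inspired by the movie...')\n",
    "except ValueError as err:\n",
    "    print('parse failed:', type(err).__name__)\n",
    "```\n",
    "\n",
    "The point of fine-tuning in this notebook is to make the smaller model emit text that survives this kind of validation."
   ]
  },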
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Log Inputs/Outputs\n",
    "\n",
    "We define some sample movie names as inputs and log the inputs and outputs of the GPT-4 program through the fine-tuning handler."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: we need >= 10 movies to use Gradient fine-tuning\n",
    "movie_names = [\n",
    "    \"The Shining\",\n",
    "    \"The Departed\",\n",
    "    \"Titanic\",\n",
    "    \"Goodfellas\",\n",
    "    \"Pretty Woman\",\n",
    "    \"Home Alone\",\n",
    "    \"Caged Fury\",\n",
    "    \"Edward Scissorhands\",\n",
    "    \"Total Recall\",\n",
    "    \"Ghost\",\n",
    "    \"Tremors\",\n",
    "    \"RoboCop\",\n",
    "    \"Rocky V\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.notebook import tqdm\n",
    "\n",
    "for movie_name in tqdm(movie_names):\n",
    "    output = openai_program(movie_name=movie_name)\n",
    "    print(output.json())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "events = finetuning_handler.get_finetuning_events()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "events"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Wrote 14 examples to mock_finetune_songs.jsonl\n"
     ]
    }
   ],
   "source": [
    "finetuning_handler.save_finetuning_events(\"mock_finetune_songs.jsonl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!cat mock_finetune_songs.jsonl"
   ]
  },
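  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Instead of `cat`, the saved file can also be inspected in Python. The sketch below is a generic JSON Lines reader; it first writes a tiny placeholder file so that it runs on its own. The `inputs` key and the record text are assumptions for illustration, so check the keys of the real `mock_finetune_songs.jsonl` rather than relying on them.\n",
    "\n",
    "```python\n",
    "import json\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "\n",
    "def read_jsonl(path):\n",
    "    # Load one JSON object per non-empty line.\n",
    "    with open(path, encoding='utf-8') as f:\n",
    "        return [json.loads(line) for line in f if line.strip()]\n",
    "\n",
    "\n",
    "# Placeholder records; the real file is written by save_finetuning_events.\n",
    "sample = [\n",
    "    {'inputs': 'prompt and completion 1'},\n",
    "    {'inputs': 'prompt and completion 2'},\n",
    "]\n",
    "path = os.path.join(tempfile.mkdtemp(), 'sample_events.jsonl')\n",
    "with open(path, 'w', encoding='utf-8') as f:\n",
    "    for record in sample:\n",
    "        print(json.dumps(record), file=f)\n",
    "\n",
    "records = read_jsonl(path)\n",
    "print(len(records), sorted(records[0].keys()))\n",
    "```\n",
    "\n",
    "Pointing `read_jsonl` at `mock_finetune_songs.jsonl` gives a quick way to count the examples and confirm that every line is valid JSON before fine-tuning."
   ]
  },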
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fine-tune on the Dataset\n",
    "\n",
    "We now define a fine-tuning engine and fine-tune on the mock dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define base model\n",
    "base_model_slug = \"llama2-7b-chat\"\n",
    "base_llm = GradientBaseModelLLM(\n",
    "    base_model_slug=base_model_slug, max_tokens=500, is_chat_model=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.finetuning import GradientFinetuneEngine\n",
    "\n",
    "finetune_engine = GradientFinetuneEngine(\n",
    "    base_model_slug=base_model_slug,\n",
    "    # model_adapter_id='805c6fd6-daa8-4fc8-a509-bebb2f2c1024_model_adapter',\n",
    "    name=\"movies_structured\",\n",
    "    data_path=\"mock_finetune_songs.jsonl\",\n",
    "    verbose=True,\n",
    "    max_steps=200,\n",
    "    batch_size=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'1f810f84-c4b8-43b0-b6b0-10d2cbdaf92f_model_adapter'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "finetune_engine.model_adapter_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# adjust epochs as necessary\n",
    "epochs = 2\n",
    "for i in range(epochs):\n",
    "    print(f\"** EPOCH {i} **\")\n",
    "    finetune_engine.finetune()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ft_llm = finetune_engine.get_finetuned_model(\n",
    "    max_tokens=500, is_chat_model=True\n",
    ")\n",
    "\n",
    "# NOTE: same as doing the following\n",
    "from llama_index.llms.gradient import GradientModelAdapterLLM\n",
    "\n",
    "# ft_llm = GradientModelAdapterLLM(\n",
    "#     model_adapter_id=finetune_engine.model_adapter_id,\n",
    "#     max_tokens=500\n",
    "# )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Try it Out!\n",
    "\n",
    "We obtain the fine-tuned LLM and use it with the Pydantic program."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# try a slightly modified prompt_template_str\n",
    "new_prompt_template_str = \"\"\"\\\n",
    "Generate an example album, with an artist and a list of songs. \\\n",
    "Using the movie {movie_name} as inspiration.\\\n",
    "\n",
    "Please only generate one album.\n",
    "\"\"\"\n",
    "\n",
    "gradient_program = LLMTextCompletionProgram.from_defaults(\n",
    "    output_parser=PydanticOutputParser(Album),\n",
    "    # prompt_template_str=prompt_template_str,\n",
    "    prompt_template_str=new_prompt_template_str,\n",
    "    llm=ft_llm,\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Album(name='Wiseguy Melodies', artist='Tommy DeVito & The Gangsters', songs=[Song(title='Life in the Fast Lane', length_seconds=210), Song(title='Money and Power', length_seconds=240), Song(title='Goodfellas', length_seconds=270), Song(title='Betrayal', length_seconds=200), Song(title='Downfall', length_seconds=180)])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gradient_program(movie_name=\"Goodfellas\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gradient_program(movie_name=\"Chucky\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# you wouldn't get this with normal llama2-7b!\n",
    "base_gradient_program = LLMTextCompletionProgram.from_defaults(\n",
    "    output_parser=PydanticOutputParser(Album),\n",
    "    prompt_template_str=prompt_template_str,\n",
    "    llm=base_llm,\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# throws an error\n",
    "base_gradient_program(movie_name=\"Goodfellas\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fine-tuning Structured Outputs through a RAG System\n",
    "\n",
    "A use case of function calling is to get structured outputs through a RAG system.\n",
    "\n",
    "Here we show how to create a training dataset of context-augmented inputs + structured outputs over an unstructured document. We can then fine-tune the LLM and plug it into a RAG system to perform retrieval + output extraction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p data && wget --user-agent \"Mozilla\" \"https://arxiv.org/pdf/2307.09288.pdf\" -O \"data/llama2.pdf\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pydantic import Field\n",
    "from typing import List\n",
    "\n",
    "\n",
    "class Citation(BaseModel):\n",
    "    \"\"\"Citation class.\"\"\"\n",
    "\n",
    "    author: str = Field(\n",
    "        ..., description=\"Inferred first author (usually last name)\"\n",
    "    )\n",
    "    year: int = Field(..., description=\"Inferred year\")\n",
    "    desc: str = Field(\n",
    "        ...,\n",
    "        description=(\n",
    "            \"Inferred description from the text of the work that the author is\"\n",
    "            \" cited for\"\n",
    "        ),\n",
    "    )\n",
    "\n",
    "\n",
    "class Response(BaseModel):\n",
    "    \"\"\"List of author citations.\n",
    "\n",
    "    Extracted over unstructured text.\n",
    "\n",
    "    \"\"\"\n",
    "\n",
    "    citations: List[Citation] = Field(\n",
    "        ...,\n",
    "        description=(\n",
    "            \"List of author citations (organized by author, year, and\"\n",
    "            \" description).\"\n",
    "        ),\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load Data + Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.readers.file import PyMuPDFReader\n",
    "from llama_index.core import Document\n",
    "from llama_index.core.node_parser import SimpleNodeParser\n",
    "from pathlib import Path\n",
    "from llama_index.core.callbacks import GradientAIFineTuningHandler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loader = PyMuPDFReader()\n",
    "docs0 = loader.load(file_path=Path(\"./data/llama2.pdf\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "doc_text = \"\\n\\n\".join([d.get_content() for d in docs0])\n",
    "metadata = {\n",
    "    \"paper_title\": \"Llama 2: Open Foundation and Fine-Tuned Chat Models\"\n",
    "}\n",
    "docs = [Document(text=doc_text, metadata=metadata)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "chunk_size = 1024\n",
    "node_parser = SimpleNodeParser.from_defaults(chunk_size=chunk_size)\n",
    "nodes = node_parser.get_nodes_from_documents(docs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "89"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(nodes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# setup GPT-4 context - to generate \"ground-truth\" data given queries\n",
    "finetuning_handler = GradientAIFineTuningHandler()\n",
    "callback_manager = CallbackManager([finetuning_handler])\n",
    "llm_gpt4 = OpenAI(model=\"gpt-4-0613\", temperature=0.3)\n",
    "llm_gpt4.pydantic_program_mode = \"llm\"\n",
    "\n",
    "\n",
    "# setup gradient.ai context\n",
    "base_model_slug = \"llama2-7b-chat\"\n",
    "base_llm = GradientBaseModelLLM(\n",
    "    base_model_slug=base_model_slug, max_tokens=500, is_chat_model=True\n",
    ")\n",
    "base_llm.pydantic_program_mode = \"llm\"\n",
    "\n",
    "# setup eval context (for question generation)\n",
    "eval_llm = OpenAI(model=\"gpt-4-0613\", temperature=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate Dataset\n",
    "\n",
    "Here we show how to generate a training dataset over these unstructured chunks/nodes.\n",
    "\n",
    "We generate questions that ask about citations in different contexts. We run these questions through a GPT-4 RAG pipeline, extract structured outputs, and log inputs/outputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 39/39 [00:27<00:00,  1.41it/s]\n"
     ]
    }
   ],
   "source": [
    "# setup dataset generator\n",
    "from llama_index.core.evaluation import DatasetGenerator\n",
    "from llama_index.core import SummaryIndex\n",
    "from llama_index.core import PromptTemplate\n",
    "from tqdm.notebook import tqdm\n",
    "from tqdm.asyncio import tqdm_asyncio\n",
    "\n",
    "\n",
    "question_gen_prompt = PromptTemplate(\n",
    "    \"\"\"\n",
    "{query_str}\n",
    "\n",
    "Context:\n",
    "{context_str}\n",
    "\n",
    "Questions:\n",
    "\"\"\"\n",
    ")\n",
    "\n",
    "question_gen_query = \"\"\"\\\n",
    "Snippets from a research paper are given below. They contain citations.\n",
    "Please generate questions from the text asking about these citations.\n",
    "\n",
    "For instance, here are some sample questions:\n",
    "Which citations correspond to related works on transformer models? \n",
    "Tell me about authors that worked on advancing RLHF.\n",
    "Can you tell me citations corresponding to all computer vision works? \\\n",
    "\"\"\"\n",
    "\n",
    "qr_pairs = []\n",
    "node_questions_tasks = []\n",
    "for idx, node in enumerate(nodes[:39]):\n",
    "    num_questions = 1  # change this number to generate more questions per node\n",
    "    dataset_generator = DatasetGenerator(\n",
    "        [node],\n",
    "        question_gen_query=question_gen_query,\n",
    "        text_question_template=question_gen_prompt,\n",
    "        llm=eval_llm,\n",
    "        metadata_mode=\"all\",\n",
    "        num_questions_per_chunk=num_questions,\n",
    "    )\n",
    "\n",
    "    task = dataset_generator.agenerate_questions_from_nodes(num=num_questions)\n",
    "    node_questions_tasks.append(task)\n",
    "node_questions_lists = await tqdm_asyncio.gather(*node_questions_tasks)"
   ]
  },
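  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell above creates one coroutine per node and awaits them together. A minimal sketch of that pattern using only `asyncio` is shown below; `fake_generate_questions` is a hypothetical stub standing in for the LLM call. The sketch uses `asyncio.run`, which suits a standalone script; inside a notebook, `await` the gather directly as the cell above does.\n",
    "\n",
    "```python\n",
    "import asyncio\n",
    "\n",
    "\n",
    "async def fake_generate_questions(node_text, num=1):\n",
    "    # Stub for an LLM call: returns placeholder questions for one node.\n",
    "    await asyncio.sleep(0)  # yield control, as a real network call would\n",
    "    return [f'Question {i + 1} about: {node_text}' for i in range(num)]\n",
    "\n",
    "\n",
    "async def generate_all(node_texts):\n",
    "    # Create every coroutine first, then await them concurrently.\n",
    "    tasks = [fake_generate_questions(text) for text in node_texts]\n",
    "    return await asyncio.gather(*tasks)\n",
    "\n",
    "\n",
    "question_lists = asyncio.run(generate_all(['chunk A', 'chunk B', 'chunk C']))\n",
    "print(len(question_lists), question_lists[1])\n",
    "```\n",
    "\n",
    "`asyncio.gather` returns results in the order the coroutines were passed in, which is why the later cells can index the question lists by node position."
   ]
  },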
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "39"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(node_questions_lists)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Which citations are mentioned in the section about RLHF Results?']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "node_questions_lists[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# [optional] save\n",
    "import pickle\n",
    "\n",
    "with open(\"llama2_questions.pkl\", \"wb\") as f:\n",
    "    pickle.dump(node_questions_lists, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# [optional] load questions\n",
    "with open(\"llama2_questions.pkl\", \"rb\") as f:\n",
    "    node_questions_lists = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core import VectorStoreIndex\n",
    "\n",
    "gpt4_index = VectorStoreIndex(nodes[:39], callback_manager=callback_manager)\n",
    "gpt4_query_engine = gpt4_index.as_query_engine(\n",
    "    output_cls=Response, llm=llm_gpt4, similarity_top_k=1\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from json import JSONDecodeError\n",
    "\n",
    "for idx, node in enumerate(tqdm(nodes[:39])):\n",
    "    node_questions_0 = node_questions_lists[idx]\n",
    "    for question in node_questions_0:\n",
    "        try:\n",
    "            # note: we don't need to use response, events are logged through fine-tuning handler\n",
    "            gpt4_query_engine.query(question)\n",
    "        except Exception as e:\n",
    "            print(f\"Error for question {question}, {repr(e)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Wrote 39 examples to llama2_citation_events.jsonl\n"
     ]
    }
   ],
   "source": [
    "finetuning_handler.save_finetuning_events(\"llama2_citation_events.jsonl\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup Fine-tuning\n",
    "\n",
    "We kick off fine-tuning over the generated dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.finetuning import GradientFinetuneEngine\n",
    "\n",
    "finetune_engine = GradientFinetuneEngine(\n",
    "    base_model_slug=base_model_slug,\n",
    "    # model_adapter_id='23a71710-47b3-43be-9be2-58a3efbccf2b_model_adapter',\n",
    "    name=\"llama2_structured\",\n",
    "    data_path=\"llama2_citation_events.jsonl\",\n",
    "    verbose=True,\n",
    "    max_steps=200,\n",
    "    batch_size=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'23a71710-47b3-43be-9be2-58a3efbccf2b_model_adapter'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# save this for future runs\n",
    "finetune_engine.model_adapter_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# adjust epochs as necessary\n",
    "epochs = 2\n",
    "for i in range(epochs):\n",
    "    print(f\"** EPOCH {i} **\")\n",
    "    finetune_engine.finetune()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Use within RAG Pipeline\n",
    "\n",
    "Let's plug the fine-tuned LLM into a full RAG pipeline that produces structured outputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ft_llm = finetune_engine.get_finetuned_model(max_tokens=500)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core import VectorStoreIndex\n",
    "\n",
    "vector_index = VectorStoreIndex(nodes)\n",
    "query_engine = vector_index.as_query_engine(\n",
    "    output_cls=Response, llm=ft_llm, similarity_top_k=1\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# setup baseline as well\n",
    "base_index = VectorStoreIndex(nodes)\n",
    "base_query_engine = base_index.as_query_engine(\n",
    "    output_cls=Response, llm=base_llm, similarity_top_k=1\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "query_str = \"Which citations are mentioned in the section about RLHF Results?\"\n",
    "# query_str = \"\"\"\\\n",
    "# Which citation corresponds to the concept of collecting data that represents \\\n",
    "# empirically sampled human preferences in RLHF?\\\n",
    "# \"\"\"\n",
    "# query_str = \"Which citations in the paper discuss the development and release of Llama 2?\"\n",
    "# query_str = \"Which citations are mentioned in the section on RLHF Results?\"\n",
    "# query_str = \"Which citation discusses the carbon output related to the production of AI hardware?\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "response = query_engine.query(query_str)\n",
    "print(str(response))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's take a look at the sources."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# view sources\n",
    "print(response.source_nodes[0].get_content())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's compare against the baseline (the base llama2-7b model). Notice that the query engine throws an error!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# throws an error!\n",
    "base_response = base_query_engine.query(query_str)\n",
    "print(str(base_response))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reference, let's also compare against GPT-4."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# as a reference, take a look at GPT-4 response\n",
    "gpt4_response = gpt4_query_engine.query(query_str)\n",
    "print(str(gpt4_response))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llama_index_v2",
   "language": "python",
   "name": "llama_index_v2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
